import os

import numpy as np
from tqdm import trange

from tinygrad.tensor import Tensor


def sparse_categorical_crossentropy(out, Y):
    """NLL-style loss for integer class labels.

    Builds a (negatively scaled) one-hot target from `Y` and reduces with a
    plain mean over every element of `out`. The `-num_classes` scaling
    compensates for the mean also averaging over the class dimension, so the
    result matches a per-row NLL loss (torch NLL returns one value per row).

    Args:
        out: Tensor of log-probabilities, shape (..., num_classes).
        Y: integer label array broadcastable to out's leading dims.
    """
    num_classes = out.shape[-1]
    labels = Y.flatten()
    onehot = np.zeros((labels.shape[0], num_classes), np.float32)
    onehot[np.arange(labels.shape[0]), labels] = -1.0 * num_classes
    target = Tensor(onehot.reshape(list(Y.shape) + [num_classes]))
    return out.mul(target).mean()


def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categorical_crossentropy):
    """Train `model` for `steps` iterations on randomly sampled minibatches.

    Args:
        model: object with a `forward(Tensor) -> Tensor` method.
        X_train, Y_train: numpy arrays of inputs and integer labels.
        optim: tinygrad optimizer with zero_grad()/step().
        steps: number of optimizer steps to run.
        BS: minibatch size (sampled with replacement).
        lossfn: callable (out, y) -> scalar loss Tensor.

    Returns:
        (losses, accuracies): per-step loss values and minibatch accuracies.
        Previously these were collected but discarded; returning them lets
        callers plot/inspect training progress without changing existing use.
    """
    Tensor.training = True
    losses, accuracies = [], []
    for _ in (t := trange(steps, disable=os.getenv('CI') is not None)):
        # sample a minibatch with replacement
        samp = np.random.randint(0, X_train.shape[0], size=BS)

        x = Tensor(X_train[samp])
        y = Y_train[samp]

        # forward pass
        out = model.forward(x)

        loss = lossfn(out, y)
        optim.zero_grad()
        loss.backward()
        optim.step()

        # minibatch accuracy from the argmax prediction
        cat = np.argmax(out.cpu().data, axis=-1)
        accuracy = (cat == y).mean()

        # record metrics and update the progress bar
        loss = loss.cpu().data
        losses.append(loss)
        accuracies.append(accuracy)
        t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
    return losses, accuracies


def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False):
    """Measure `model` accuracy on a test set using batched forward passes.

    Args:
        model: object with a `forward(Tensor) -> Tensor` method.
        X_test, Y_test: numpy arrays of inputs and integer labels.
        num_classes: size of the output class dimension; inferred from
            Y_test's maximum label when None.
        BS: evaluation batch size.
        return_predict: when True, also return the argmax predictions.

    Returns:
        accuracy, or (accuracy, predictions) if return_predict is True.
    """
    Tensor.training = False

    if num_classes is None:
        num_classes = Y_test.max().astype(int) + 1

    # accumulate raw model outputs batch by batch, then argmax once
    preds_raw = np.zeros(list(Y_test.shape) + [num_classes])
    n_batches = (len(Y_test) - 1) // BS + 1
    for b in trange(n_batches, disable=os.getenv('CI') is not None):
        lo, hi = b * BS, (b + 1) * BS
        preds_raw[lo:hi] = model.forward(Tensor(X_test[lo:hi])).cpu().data
    Y_test_pred = np.argmax(preds_raw, axis=-1)

    acc = (Y_test == Y_test_pred).mean()
    print("test set accuracy is %f" % acc)
    return (acc, Y_test_pred) if return_predict else acc
